{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# {func}`~tvm.target.generic_func`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/read\n"
     ]
    }
   ],
   "source": [
    "%cd ..\n",
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n",
      "4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm/python/tvm/target/target.py:446: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.\n",
      "  warnings.warn(\"Try specifying cuda arch by adding 'arch=sm_xx' to your target.\")\n",
      "[20:02:00] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n"
     ]
    }
   ],
   "source": [
    "import tvm\n",
    "# wrap function as target generic\n",
    "@tvm.target.generic_func\n",
    "def my_func(a):\n",
    "    return a + 1\n",
    "# register specialization of my_func under target cuda\n",
    "@my_func.register(\"cuda\")\n",
    "def my_func_cuda(a):\n",
    "    return a + 2\n",
    "# displays 3, because my_func is called\n",
    "print(my_func(2))\n",
    "# displays 4, because my_func_cuda is called\n",
    "with tvm.target.cuda():\n",
    "    print(my_func(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import pytest\n",
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm.target import Target, arm_cpu, bifrost, cuda, intel_graphics, mali, rocm, vta\n",
    "\n",
    "\n",
    "@tvm.target.generic_func\n",
    "def mygeneric(data):\n",
    "    # default generic function\n",
    "    return data + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@mygeneric.register([\"cuda\", \"gpu\"])\n",
    "def cuda_func(data):\n",
    "    return data + 2\n",
    "\n",
    "\n",
    "@mygeneric.register(\"rocm\")\n",
    "def rocm_func(data):\n",
    "    return data + 3\n",
    "\n",
    "\n",
    "@mygeneric.register(\"cpu\")\n",
    "def rocm_func(data):\n",
    "    return data + 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "所有目标设备类型的一致性验证："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ROCm not detected, using default gfx900\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n",
      "[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:186: Warning: Unable to detect CUDA version, default to \"-mcpu=sm_50\" instead\n",
      "[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:224: Warning: Unable to detect ROCm version, assuming >= 3.5\n"
     ]
    }
   ],
   "source": [
    "all_targets = [tvm.target.Target(t) for t in tvm.target.Target.list_kinds()]\n",
    "\n",
    "for tgt in all_targets:\n",
    "    # skip targets with hooks or otherwise intended to be used with external codegen\n",
    "    relay_to_tir = tgt.get_kind_attr(\"RelayToTIR\")\n",
    "    tir_to_runtime = tgt.get_kind_attr(\"TIRToRuntime\")\n",
    "    is_external_codegen = tgt.get_kind_attr(\"is_external_codegen\")\n",
    "    if relay_to_tir is not None or tir_to_runtime is not None or is_external_codegen:\n",
    "        continue\n",
    "\n",
    "    if tgt.kind.name not in tvm._ffi.runtime_ctypes.Device.STR2MASK:\n",
    "        raise KeyError(\"Cannot find target kind: %s in Device.STR2MASK\" % tgt.kind.name)\n",
    "\n",
    "    assert (\n",
    "        tgt.get_target_device_type() == tvm._ffi.runtime_ctypes.Device.STR2MASK[tgt.kind.name]\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_target_dispatch():\n",
    "    with tvm.target.cuda():\n",
    "        assert mygeneric(1) == 3\n",
    "        assert mygeneric.get_packed_func()(1) == 3\n",
    "\n",
    "    with tvm.target.rocm():\n",
    "        assert mygeneric(1) == 4\n",
    "        assert mygeneric.get_packed_func()(1) == 4\n",
    "\n",
    "    with tvm.target.Target(\"cuda\"):\n",
    "        assert mygeneric(1) == 3\n",
    "        assert mygeneric.get_packed_func()(1) == 3\n",
    "\n",
    "    with tvm.target.arm_cpu():\n",
    "        assert mygeneric(1) == 11\n",
    "        assert mygeneric.get_packed_func()(1) == 11\n",
    "\n",
    "    with tvm.target.Target(\"metal\"):\n",
    "        assert mygeneric(1) == 3\n",
    "        assert mygeneric.get_packed_func()(1) == 3\n",
    "\n",
    "    assert tvm.target.Target.current() is None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ROCm not detected, using default gfx900\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm/python/tvm/target/target.py:446: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.\n",
      "  warnings.warn(\"Try specifying cuda arch by adding 'arch=sm_xx' to your target.\")\n",
      "[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n",
      "[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:224: Warning: Unable to detect ROCm version, assuming >= 3.5\n",
      "[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n"
     ]
    }
   ],
   "source": [
    "test_target_dispatch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
